TRPO (Trust Region Policy Optimization) — low-level PyTorch implementation#
TRPO is an on-policy policy-gradient method that aims for stable, approximately monotonic improvement by constraining how much the policy is allowed to change in each iteration via a KL-divergence trust region.
In this notebook you will:
Derive the KL constraint (LaTeX) and how it leads to a natural-gradient step
Implement TRPO “from scratch” with PyTorch autograd + conjugate gradient + backtracking line search
Visualize policy updates, KL per update, and episodic returns with Plotly
See a reference Stable-Baselines TRPO implementation and understand its hyperparameters
Notebook roadmap#
TRPO objective + the KL-divergence constraint (math)
A tiny offline-friendly continuous-control environment (no downloads)
Gaussian policy + value baseline (PyTorch)
GAE advantages + value function fit
TRPO update step (Fisher-vector product, conjugate gradient, line search)
Plotly: episodic rewards, KL constraint, policy update snapshots
Stable-Baselines TRPO: usage + hyperparameters (end)
import sys
import time
import numpy as np
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots
import torch
import torch.nn as nn
import torch.nn.functional as F
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)
DEVICE = torch.device("cpu")
SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
print("Python:", sys.version.split()[0])
print("NumPy:", np.__version__)
import plotly
print("Plotly:", plotly.__version__)
print("PyTorch:", torch.__version__)
print("Device:", DEVICE)
Python: 3.12.9
NumPy: 1.26.2
Plotly: 6.5.2
PyTorch: 2.7.0+cu126
Device: cpu
1) TRPO objective and the KL-divergence constraint#
TRPO is usually presented as the constrained optimization problem:
[ \max_\theta\; \mathbb{E}_{s,a\sim \pi_{\theta_{\text{old}}}}\left[\frac{\pi_\theta(a\mid s)}{\pi_{\theta_{\text{old}}}(a\mid s)}\,\hat A_{\theta_{\text{old}}}(s,a)\right] \qquad\text{s.t.}\qquad \mathbb{E}_{s\sim \pi_{\theta_{\text{old}}}}\left[D_{\mathrm{KL}}\!\left(\pi_{\theta_{\text{old}}}(\cdot\mid s)\,\|\,\pi_\theta(\cdot\mid s)\right)\right] \le \delta. ]
The trust region is defined by the average KL divergence under states visited by the old policy. Intuition: “move in a direction that increases the objective, but don’t move too far in policy space.”
We use the standard definition:
[ D_{\mathrm{KL}}(p\,\|\,q) = \mathbb{E}_{x\sim p}\left[\log\frac{p(x)}{q(x)}\right]. ]
1.1) Why this leads to a natural-gradient step#
Let (\theta) be the policy parameters and (\theta_{\text{old}}) the pre-update parameters.
TRPO uses two approximations around (\theta_{\text{old}}):
First-order (linear) approximation of the surrogate objective:
[ L(\theta) \approx L(\theta_{\text{old}}) + g^\top (\theta - \theta_{\text{old}}) \quad\text{where}\quad g = \nabla_\theta L(\theta)\big\rvert_{\theta=\theta_{\text{old}}}. ]
Second-order (quadratic) approximation of the KL constraint:
[ \bar D_{\mathrm{KL}}(\theta_{\text{old}},\theta) \approx \tfrac12 (\theta - \theta_{\text{old}})^\top H (\theta - \theta_{\text{old}}), ]
where (H) is the Hessian of the average KL at (\theta_{\text{old}}); evaluated at (\theta_{\text{old}}), this Hessian coincides with the policy’s Fisher information matrix (under standard regularity conditions).
Define the step (p = \theta - \theta_{\text{old}}). The constrained problem becomes:
[ \max_p\; g^\top p \qquad\text{s.t.}\qquad \tfrac12 p^\top H p \le \delta. ]
The solution is:
[ p^* = \sqrt{\frac{2\delta}{g^\top H^{-1} g}}\; H^{-1} g. ]
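This follows from the KKT conditions: the objective is linear in (p), so the optimum lies on the boundary of the trust region; the Lagrangian makes the step proportional to (H^{-1} g), and the constraint fixes the scale:
[ p = \tfrac{1}{\lambda} H^{-1} g, \qquad \tfrac12 p^\top H p = \frac{g^\top H^{-1} g}{2\lambda^2} = \delta \quad\Rightarrow\quad \lambda = \sqrt{\frac{g^\top H^{-1} g}{2\delta}}. ]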
So we need:
The policy-gradient (g)
The product (H^{-1} g) (without forming (H) explicitly) → conjugate gradient + Hessian-vector products
A step scaling + backtracking line search to satisfy the true KL constraint and improve the surrogate.
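As a concrete sanity check on the step formula above, here is a tiny hand-made 2-parameter example; the matrix H, vector g, and delta below are made-up stand-ins, not quantities computed in this notebook:
import torch

# Toy trust-region step: maximize g^T p subject to 0.5 * p^T H p <= delta.
H = torch.tensor([[2.0, 0.0], [0.0, 0.5]])  # stand-in for the KL Hessian / Fisher matrix
g = torch.tensor([1.0, 1.0])                # stand-in for the surrogate gradient
delta = 0.01                                # trust-region radius (max_kl)

Hinv_g = torch.linalg.solve(H, g)           # TRPO approximates this product with conjugate gradient
p_star = torch.sqrt(2 * delta / (g @ Hinv_g)) * Hinv_g

print("step:", p_star)
print("quadratic KL model at the step:", float(0.5 * p_star @ H @ p_star))  # should equal delta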
2) A tiny offline-friendly continuous-control environment#
To keep the notebook self-contained (no Gym downloads), we use a 1D point-mass with state (s=(x,v)) and action (a\in[-1,1]):
Dynamics: small acceleration changes velocity, velocity changes position
Goal: reach (x=0) with small velocity
Reward: negative quadratic cost (plus a small terminal bonus when reaching the goal)
This is not meant to be a benchmark; it’s just enough to show that TRPO learns and that the KL trust region stabilizes updates.
class PointMass1DEnv:
def __init__(
self,
dt: float = 0.05,
max_steps: int = 150,
x_init_range: float = 2.0,
v_init_range: float = 0.5,
action_max: float = 1.0,
goal_x: float = 0.0,
goal_tol: float = 0.05,
goal_bonus: float = 5.0,
seed: int | None = None,
):
self.dt = float(dt)
self.max_steps = int(max_steps)
self.x_init_range = float(x_init_range)
self.v_init_range = float(v_init_range)
self.action_max = float(action_max)
self.goal_x = float(goal_x)
self.goal_tol = float(goal_tol)
self.goal_bonus = float(goal_bonus)
self.rng = np.random.default_rng(seed)
self.steps = 0
self.x = 0.0
self.v = 0.0
@property
def obs_dim(self):
return 2
@property
def act_dim(self):
return 1
def reset(self, seed: int | None = None):
if seed is not None:
self.rng = np.random.default_rng(seed)
self.steps = 0
self.x = self.rng.uniform(-self.x_init_range, self.x_init_range)
self.v = self.rng.uniform(-self.v_init_range, self.v_init_range)
return np.array([self.x, self.v], dtype=np.float32)
def step(self, action):
# accept python scalars or length-1 arrays without NumPy's array-to-scalar deprecation
a = float(np.clip(np.asarray(action).reshape(-1)[0], -self.action_max, self.action_max))
# simple damped dynamics
self.v = 0.99 * self.v + a * self.dt
self.x = self.x + self.v * self.dt
self.steps += 1
# quadratic cost around the goal
cost = (self.x - self.goal_x) ** 2 + 0.1 * (self.v**2) + 0.001 * (a**2)
reward = -float(cost)
done = False
if abs(self.x - self.goal_x) < self.goal_tol and abs(self.v) < self.goal_tol:
done = True
reward += float(self.goal_bonus)
if self.steps >= self.max_steps:
done = True
obs = np.array([self.x, self.v], dtype=np.float32)
return obs, reward, done, {}
env = PointMass1DEnv(seed=SEED)
obs = env.reset()
xs, vs, acts, rews = [obs[0]], [obs[1]], [], []
done = False
while not done:
a = rng.uniform(-1.0, 1.0)
obs, r, done, _ = env.step(a)
xs.append(obs[0])
vs.append(obs[1])
acts.append(a)
rews.append(r)
fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(y=xs, mode="lines", name="x"), row=1, col=1)
fig.add_trace(go.Scatter(y=vs, mode="lines", name="v"), row=2, col=1)
fig.add_trace(go.Scatter(y=acts, mode="lines", name="a"), row=3, col=1)
fig.update_layout(
title="One random rollout in the toy env",
height=650,
showlegend=True,
)
fig.update_yaxes(title_text="position x", row=1, col=1)
fig.update_yaxes(title_text="velocity v", row=2, col=1)
fig.update_yaxes(title_text="action a", row=3, col=1)
fig.update_xaxes(title_text="time step", row=3, col=1)
fig.show()
print("Return (sum reward):", float(np.sum(rews)))
Return (sum reward): -206.57660868987693
3) Policy and value function (PyTorch)#
We’ll use:
A Gaussian policy (\pi_\theta(a\mid s)=\mathcal{N}(\mu_\theta(s),\sigma_\theta^2)) with diagonal covariance and a state-independent learned log-std (here 1D)
A value network (V_\phi(s)) as a baseline
For TRPO we need:
(\log \pi_\theta(a\mid s)) to compute the surrogate objective
The KL between old and new Gaussian policies to build the trust region (and its Hessian-vector product)
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
layers = []
for i in range(len(sizes) - 1):
act = activation if i < len(sizes) - 2 else output_activation
layers.append(nn.Linear(sizes[i], sizes[i + 1]))
layers.append(act())
return nn.Sequential(*layers)
class GaussianPolicy(nn.Module):
def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
super().__init__()
self.net = mlp([obs_dim, *hidden_sizes, act_dim], activation=nn.Tanh)
self.log_std = nn.Parameter(torch.zeros(act_dim))
def forward(self, obs: torch.Tensor):
mean = self.net(obs)
log_std = self.log_std.expand_as(mean)
return mean, log_std
def dist(self, obs: torch.Tensor):
mean, log_std = self.forward(obs)
return torch.distributions.Normal(mean, torch.exp(log_std))
@torch.no_grad()
def act(self, obs: torch.Tensor):
dist = self.dist(obs)
action = dist.sample()
logp = dist.log_prob(action).sum(-1)
return action, logp
class ValueNet(nn.Module):
def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
super().__init__()
self.net = mlp([obs_dim, *hidden_sizes, 1], activation=nn.Tanh)
def forward(self, obs: torch.Tensor):
return self.net(obs).squeeze(-1)
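A quick shape check on the two networks, using the toy env's dimensions (obs_dim=2, act_dim=1); this is only a sanity check, not part of training:
# Push a small batch of fake observations through untrained networks.
_pi = GaussianPolicy(obs_dim=2, act_dim=1)
_vf = ValueNet(obs_dim=2)
_obs = torch.randn(5, 2)

_a, _logp = _pi.act(_obs)
print(_a.shape, _logp.shape)   # torch.Size([5, 1]) torch.Size([5])
print(_vf(_obs).shape)         # torch.Size([5])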
4) TRPO building blocks (low-level)#
We implement:
GAE((\gamma,\lambda)) for advantages
Value function regression
Conjugate gradient for solving (H x = g)
Fisher/Hessian-vector product via autograd on the mean KL
Backtracking line search enforcing the KL constraint
def gaussian_kl(mean_old, log_std_old, mean_new, log_std_new):
"""KL( N_old || N_new ) for diagonal Gaussians; returns shape (batch,)."""
var_old = torch.exp(2.0 * log_std_old)
var_new = torch.exp(2.0 * log_std_new)
kl_per_dim = (
log_std_new
- log_std_old
+ (var_old + (mean_old - mean_new) ** 2) / (2.0 * var_new)
- 0.5
)
return kl_per_dim.sum(dim=-1)
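As a quick correctness check (not used during training), the closed-form KL above can be compared against torch.distributions.kl_divergence on random parameters:
# Compare the hand-written diagonal-Gaussian KL with PyTorch's built-in KL.
_m_old, _m_new = torch.randn(4, 1), torch.randn(4, 1)
_ls_old, _ls_new = 0.1 * torch.randn(4, 1), 0.1 * torch.randn(4, 1)

_kl_ours = gaussian_kl(_m_old, _ls_old, _m_new, _ls_new)
_kl_ref = torch.distributions.kl_divergence(
    torch.distributions.Normal(_m_old, _ls_old.exp()),
    torch.distributions.Normal(_m_new, _ls_new.exp()),
).sum(-1)
print(torch.allclose(_kl_ours, _kl_ref, atol=1e-6))  # expected: True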
def flat_params(model: nn.Module):
return torch.cat([p.data.view(-1) for p in model.parameters()])
def set_flat_params(model: nn.Module, flat: torch.Tensor):
idx = 0
with torch.no_grad():
for p in model.parameters():
n = p.numel()
p.copy_(flat[idx : idx + n].view_as(p))
idx += n
def flat_grad(grads, params):
out = []
for g, p in zip(grads, params):
if g is None:
out.append(torch.zeros_like(p).view(-1))
else:
out.append(g.contiguous().view(-1))
return torch.cat(out)
def conjugate_gradient(fvp_fn, b, cg_iters=10, residual_tol=1e-10):
x = torch.zeros_like(b)
r = b.clone()
p = b.clone()
rdotr = torch.dot(r, r)
for _ in range(cg_iters):
Avp = fvp_fn(p)
alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
x = x + alpha * p
r = r - alpha * Avp
new_rdotr = torch.dot(r, r)
if new_rdotr < residual_tol:
break
beta = new_rdotr / (rdotr + 1e-8)
p = r + beta * p
rdotr = new_rdotr
return x
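A small check of the conjugate-gradient routine against a direct solve on a random symmetric positive-definite system (a verification sketch, not part of the TRPO update):
# conjugate_gradient should recover A^{-1} b when given an exact matrix-vector product.
_A = torch.randn(6, 6)
_A = _A @ _A.T + 6.0 * torch.eye(6)   # make the system symmetric positive-definite
_b = torch.randn(6)

_x_cg = conjugate_gradient(lambda v: _A @ v, _b, cg_iters=20)
_x_ref = torch.linalg.solve(_A, _b)
print(torch.allclose(_x_cg, _x_ref, atol=1e-4))  # expected: True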
def trpo_update(
policy: GaussianPolicy,
obs: torch.Tensor,
act: torch.Tensor,
adv: torch.Tensor,
logp_old: torch.Tensor,
max_kl: float = 0.01,
cg_iters: int = 10,
cg_damping: float = 1e-2,
backtrack_iters: int = 10,
backtrack_coeff: float = 0.8,
):
"""One TRPO policy update step."""
params = list(policy.parameters())
old_params = flat_params(policy)
with torch.no_grad():
mean_old, log_std_old = policy.forward(obs)
mean_old = mean_old.detach()
log_std_old = log_std_old.detach()
def surrogate():
dist = policy.dist(obs)
logp = dist.log_prob(act).sum(-1)
ratio = torch.exp(logp - logp_old)
return (ratio * adv).mean()
def mean_kl():
mean_new, log_std_new = policy.forward(obs)
return gaussian_kl(mean_old, log_std_old, mean_new, log_std_new).mean()
surr = surrogate()
g = torch.autograd.grad(surr, params, retain_graph=True, allow_unused=True)
g_flat = flat_grad(g, params).detach()
def fvp(v):
kl = mean_kl()
grads = torch.autograd.grad(kl, params, create_graph=True, allow_unused=True)
flat_kl_grad = flat_grad(grads, params)
kl_v = torch.dot(flat_kl_grad, v)
grads2 = torch.autograd.grad(kl_v, params, allow_unused=True)
hvp = flat_grad(grads2, params).detach()
return hvp + cg_damping * v
step_dir = conjugate_gradient(fvp, g_flat, cg_iters=cg_iters)
shs = torch.dot(step_dir, fvp(step_dir))
step_size = torch.sqrt(torch.tensor(2.0 * max_kl, dtype=shs.dtype) / (shs + 1e-8))
full_step = step_dir * step_size
def eval_surr_and_kl():
with torch.no_grad():
s = surrogate().item()
k = mean_kl().item()
return s, k
surr_old_val, _ = eval_surr_and_kl()
step_frac = 1.0
accepted = False
surr_new_val = surr_old_val
kl_new_val = 0.0
for _ in range(backtrack_iters):
new_params = old_params + step_frac * full_step
set_flat_params(policy, new_params)
surr_new_val, kl_new_val = eval_surr_and_kl()
if (surr_new_val > surr_old_val) and (kl_new_val <= max_kl):
accepted = True
break
step_frac *= backtrack_coeff
if not accepted:
set_flat_params(policy, old_params)
return {
"surr_old": float(surr_old_val),
"surr_new": float(surr_new_val),
"kl": float(kl_new_val),
"step_frac": float(step_frac if accepted else 0.0),
"accepted": bool(accepted),
}
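A smoke test on random data (shapes chosen to match the toy env) just to show the diagnostics dictionary that trpo_update returns; the numbers themselves are meaningless here:
# One TRPO step on a fresh policy with random observations, actions and advantages.
_pi = GaussianPolicy(obs_dim=2, act_dim=1)
_obs = torch.randn(64, 2)
_act = torch.randn(64, 1)
_adv = torch.randn(64)
with torch.no_grad():
    _logp_old = _pi.dist(_obs).log_prob(_act).sum(-1)

print(trpo_update(_pi, _obs, _act, _adv, _logp_old, max_kl=0.01))
# e.g. {'surr_old': ..., 'surr_new': ..., 'kl': ..., 'step_frac': ..., 'accepted': True}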
def collect_batch(env, policy, value_net, steps_per_batch, gamma=0.99, lam=0.98):
obs_buf = np.zeros((steps_per_batch, env.obs_dim), dtype=np.float32)
act_buf = np.zeros((steps_per_batch, env.act_dim), dtype=np.float32)
rew_buf = np.zeros(steps_per_batch, dtype=np.float32)
done_buf = np.zeros(steps_per_batch, dtype=np.float32)
val_buf = np.zeros(steps_per_batch, dtype=np.float32)
logp_buf = np.zeros(steps_per_batch, dtype=np.float32)
ep_returns = []
ep_ret = 0.0
obs = env.reset()
for t in range(steps_per_batch):
obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
with torch.no_grad():
a_t, logp_t = policy.act(obs_t)
v_t = value_net(obs_t)
a = a_t.squeeze(0).cpu().numpy()
logp = float(logp_t.item())
v = float(v_t.item())
next_obs, r, done, _ = env.step(a)
obs_buf[t] = obs
act_buf[t] = a
rew_buf[t] = r
done_buf[t] = float(done)
val_buf[t] = v
logp_buf[t] = logp
ep_ret += float(r)
obs = next_obs
if done:
ep_returns.append(ep_ret)
ep_ret = 0.0
obs = env.reset()
# bootstrap value for the last state (if last transition wasn't terminal)
with torch.no_grad():
last_val = value_net(
torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
).item()
adv_buf = np.zeros(steps_per_batch, dtype=np.float32)
last_gae = 0.0
for t in reversed(range(steps_per_batch)):
if t == steps_per_batch - 1:
next_nonterminal = 1.0 - done_buf[t]
next_value = last_val
else:
next_nonterminal = 1.0 - done_buf[t]
next_value = val_buf[t + 1]
delta = rew_buf[t] + gamma * next_value * next_nonterminal - val_buf[t]
last_gae = delta + gamma * lam * next_nonterminal * last_gae
adv_buf[t] = last_gae
ret_buf = adv_buf + val_buf
# normalize advantages (very common and usually helpful)
adv_buf = (adv_buf - adv_buf.mean()) / (adv_buf.std() + 1e-8)
batch = {
"obs": torch.as_tensor(obs_buf, dtype=torch.float32, device=DEVICE),
"act": torch.as_tensor(act_buf, dtype=torch.float32, device=DEVICE),
"logp_old": torch.as_tensor(logp_buf, dtype=torch.float32, device=DEVICE),
"adv": torch.as_tensor(adv_buf, dtype=torch.float32, device=DEVICE),
"ret": torch.as_tensor(ret_buf, dtype=torch.float32, device=DEVICE),
"ep_returns": ep_returns,
}
return batch
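To make the GAE recursion in collect_batch concrete, here is the same backward pass on a tiny hand-made trajectory (toy numbers, no environment involved):
# GAE on a 4-step trajectory with constant reward 1, zero value estimates, no terminals.
_rew = np.ones(4, dtype=np.float32)
_val = np.zeros(4, dtype=np.float32)
_gamma, _lam, _last_val = 0.99, 0.98, 0.0

_adv = np.zeros(4, dtype=np.float32)
_gae = 0.0
for _t in reversed(range(4)):
    _next_val = _last_val if _t == 3 else _val[_t + 1]
    _delta = _rew[_t] + _gamma * _next_val - _val[_t]
    _gae = _delta + _gamma * _lam * _gae
    _adv[_t] = _gae
print(_adv)  # earlier steps see more discounted future reward, so their advantages are larger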
# --- Run configuration ---
FAST_RUN = True # set False for a longer run
TOTAL_ITERS = 25 if FAST_RUN else 150
STEPS_PER_BATCH = 1024 if FAST_RUN else 4096
GAMMA = 0.99
LAMBDA = 0.98
MAX_KL = 0.01
CG_ITERS = 10
CG_DAMPING = 1e-2
BACKTRACK_ITERS = 10
BACKTRACK_COEFF = 0.8
VF_LR = 3e-4
VF_ITERS = 10 if FAST_RUN else 80
VF_BATCH = 128
SNAPSHOT_EVERY = 5
env = PointMass1DEnv(seed=SEED)
policy = GaussianPolicy(env.obs_dim, env.act_dim, hidden_sizes=(64, 64)).to(DEVICE)
value_net = ValueNet(env.obs_dim, hidden_sizes=(64, 64)).to(DEVICE)
vf_optim = torch.optim.Adam(value_net.parameters(), lr=VF_LR)
x_grid = np.linspace(-env.x_init_range, env.x_init_range, 101, dtype=np.float32)
history = {
"iter": [],
"ep_ret_mean": [],
"ep_ret_p10": [],
"ep_ret_p90": [],
"kl": [],
"surr_old": [],
"surr_new": [],
"step_frac": [],
"policy_std": [],
}
policy_snapshots = []
t0 = time.time()
for it in range(TOTAL_ITERS):
batch = collect_batch(
env,
policy,
value_net,
steps_per_batch=STEPS_PER_BATCH,
gamma=GAMMA,
lam=LAMBDA,
)
# --- Fit value function ---
for _ in range(VF_ITERS):
n = batch["obs"].shape[0]
bs = min(VF_BATCH, n)
idx = torch.as_tensor(rng.choice(n, size=bs, replace=False), device=DEVICE)
v_pred = value_net(batch["obs"][idx])
v_loss = F.mse_loss(v_pred, batch["ret"][idx])
vf_optim.zero_grad()
v_loss.backward()
vf_optim.step()
# --- TRPO policy update ---
stats = trpo_update(
policy,
obs=batch["obs"],
act=batch["act"],
adv=batch["adv"],
logp_old=batch["logp_old"],
max_kl=MAX_KL,
cg_iters=CG_ITERS,
cg_damping=CG_DAMPING,
backtrack_iters=BACKTRACK_ITERS,
backtrack_coeff=BACKTRACK_COEFF,
)
# --- Metrics ---
ep_returns = batch["ep_returns"]
if len(ep_returns) > 0:
ep_mean = float(np.mean(ep_returns))
ep_p10 = float(np.percentile(ep_returns, 10))
ep_p90 = float(np.percentile(ep_returns, 90))
else:
ep_mean, ep_p10, ep_p90 = float("nan"), float("nan"), float("nan")
with torch.no_grad():
policy_std = float(torch.exp(policy.log_std).mean().item())
history["iter"].append(it)
history["ep_ret_mean"].append(ep_mean)
history["ep_ret_p10"].append(ep_p10)
history["ep_ret_p90"].append(ep_p90)
history["kl"].append(stats["kl"])
history["surr_old"].append(stats["surr_old"])
history["surr_new"].append(stats["surr_new"])
history["step_frac"].append(stats["step_frac"])
history["policy_std"].append(policy_std)
# snapshot policy mean(action|x,v=0) over a grid
if (it == 0) or (it % SNAPSHOT_EVERY == 0) or (it == TOTAL_ITERS - 1):
obs_grid = np.stack([x_grid, np.zeros_like(x_grid)], axis=1)
with torch.no_grad():
mu, _ = policy.forward(torch.as_tensor(obs_grid, dtype=torch.float32, device=DEVICE))
policy_snapshots.append({"iter": it, "mu": mu.squeeze(-1).cpu().numpy()})
if (it + 1) % max(1, TOTAL_ITERS // 5) == 0 or it == 0:
print(
f"iter {it:03d} | ep_ret_mean {ep_mean:8.2f} | KL {stats['kl']:.4f} | "
f"step_frac {stats['step_frac']:.3f} | std {policy_std:.3f}"
)
print(f"Done in {time.time() - t0:.2f}s")
iter 000 | ep_ret_mean -1083.82 | KL 0.0086 | step_frac 1.000 | std 0.997
iter 004 | ep_ret_mean -333.98 | KL 0.0092 | step_frac 1.000 | std 0.948
iter 009 | ep_ret_mean -117.50 | KL 0.0070 | step_frac 0.800 | std 0.940
iter 014 | ep_ret_mean -2175.23 | KL 0.0068 | step_frac 0.640 | std 1.018
iter 019 | ep_ret_mean -2436.67 | KL 0.0049 | step_frac 1.000 | std 1.005
iter 024 | ep_ret_mean -63.65 | KL 0.0089 | step_frac 1.000 | std 1.044
Done in 3.64s
# Plotly: learning curves and trust-region diagnostics
iters = history["iter"]
fig = make_subplots(
rows=3,
cols=1,
shared_xaxes=True,
vertical_spacing=0.08,
subplot_titles=(
"Episodic return (mean + 10/90 percentile band)",
"Mean KL(old || new) per update (should be ≤ max_kl)",
"Policy std (exp(log_std))",
),
)
# return band
fig.add_trace(
go.Scatter(x=iters, y=history["ep_ret_p90"], mode="lines", line=dict(width=0), showlegend=False),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=iters,
y=history["ep_ret_p10"],
mode="lines",
fill="tonexty",
line=dict(width=0),
name="p10–p90",
opacity=0.25,
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(x=iters, y=history["ep_ret_mean"], mode="lines+markers", name="mean"),
row=1,
col=1,
)
# KL curve
fig.add_trace(
go.Scatter(x=iters, y=history["kl"], mode="lines+markers", name="KL"),
row=2,
col=1,
)
fig.add_hline(y=MAX_KL, line_dash="dash", line_color="black", row=2, col=1)
# policy std
fig.add_trace(
go.Scatter(x=iters, y=history["policy_std"], mode="lines+markers", name="std"),
row=3,
col=1,
)
fig.update_layout(height=850, title="TRPO learning diagnostics")
fig.update_xaxes(title_text="iteration", row=3, col=1)
fig.update_yaxes(title_text="return", row=1, col=1)
fig.update_yaxes(title_text="KL", row=2, col=1)
fig.update_yaxes(title_text="std", row=3, col=1)
fig.show()
# Plotly: how the policy mean changes over iterations
fig = go.Figure()
for snap in policy_snapshots:
fig.add_trace(
go.Scatter(
x=x_grid,
y=snap["mu"],
mode="lines",
name=f"iter {snap['iter']}",
)
)
fig.update_layout(
title="Policy mean action μ(x, v=0) snapshots",
xaxis_title="position x (with v fixed at 0)",
yaxis_title="mean action μ",
height=450,
)
fig.show()
5) Stable-Baselines TRPO (reference implementation)#
TRPO is available in the original stable-baselines (TensorFlow) project as stable_baselines.trpo_mpi.TRPO (and is re-exported as stable_baselines.TRPO when mpi4py is installed).
Example usage (not executed here):
import gym
# Requires the original stable-baselines (TensorFlow) + mpi4py.
from stable_baselines import TRPO
from stable_baselines.common.policies import MlpPolicy
env = gym.make("CartPole-v1")
model = TRPO(
MlpPolicy,
env,
gamma=0.99,
timesteps_per_batch=1024,
max_kl=0.01,
cg_iters=10,
lam=0.98,
entcoeff=0.0,
cg_damping=1e-2,
vf_stepsize=3e-4,
vf_iters=3,
verbose=1,
)
model.learn(total_timesteps=200_000)
Source used to verify signature and defaults:
https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/trpo_mpi/trpo_mpi.py
Stable-Baselines TRPO hyperparameters (what they mean)#
From the upstream TRPO.__init__ signature:
gamma — discount factor (\gamma)
timesteps_per_batch — on-policy batch size (number of environment steps collected before each TRPO update)
max_kl — trust-region radius (\delta): target/upper bound on mean KL(old || new)
cg_iters — number of conjugate-gradient iterations used to approximately solve (H x = g)
lam — GAE parameter (\lambda) controlling bias/variance tradeoff in advantages
entcoeff — entropy bonus coefficient (encourages exploration by penalizing low entropy)
cg_damping — adds a small multiple of the identity to the Fisher/Hessian-vector product for numerical stability
vf_stepsize — learning rate for the value function optimizer
vf_iters — number of value-function optimization iterations per update
tensorboard_log / full_tensorboard_log — logging configuration
policy_kwargs — extra arguments passed to the policy network constructor
seed — RNG seed
n_cpu_tf_sess — TensorFlow session CPU threading configuration
A good way to tune TRPO is to start with:
max_kl around 0.01, adjusting up/down to trade faster learning against stability
timesteps_per_batch larger for smoother updates (at higher compute cost)
cg_damping slightly larger if updates become numerically unstable